import jax
import jax.numpy as jnp
import flax.linen as nn 
import jax.random as random

def runge_kutta_step(f, x, t, h):
    """Return the classical fourth-order Runge-Kutta slope for one step.

    Note that this returns the weighted average slope, not the updated
    state: the caller is expected to apply ``x + h * slope`` itself.

    Args:
        f: Callable ``f(x, t)`` giving the time derivative of the state.
        x: Current state.
        t: Current time.
        h: Step size used to place the intermediate evaluations (may be
            negative to step backwards in time).

    Returns:
        ``(k1 + 2*k2 + 2*k3 + k4) / 6`` for the four RK4 stage slopes.
    """
    slope_start = f(x, t)
    half_h = h/2
    t_mid = t + half_h
    # Two midpoint evaluations, each seeded by the previous stage's slope.
    slope_mid_a = f(x + half_h*slope_start, t_mid)
    slope_mid_b = f(x + half_h*slope_mid_a, t_mid)
    slope_end = f(x + h*slope_mid_b, t + h)
    weighted_sum = slope_start + 2*slope_mid_a + 2*slope_mid_b + slope_end
    return weighted_sum/6

def euler_step(f, x, t, h):
    """Return the explicit-Euler slope, i.e. ``f`` evaluated at ``(x, t)``.

    Args:
        f: Callable ``f(x, t)`` giving the time derivative of the state.
        x: Current state.
        t: Current time.
        h: Step size; accepted so the signature matches the other step
            functions, but not used.

    Returns:
        ``f(x, t)``; the caller applies the step itself.
    """
    slope = f(x, t)
    return slope

def midpoint_step(f, x, t, h):
    """Return the explicit-midpoint slope for one step.

    Args:
        f: Callable ``f(x, t)`` giving the time derivative of the state.
        x: Current state.
        t: Current time.
        h: Step size used to locate the midpoint.

    Returns:
        ``f`` evaluated at the half-step point reached by following the
        slope at ``(x, t)``; the caller applies the step itself.
    """
    slope_start = f(x, t)
    half_h = h/2
    x_mid = x + half_h*slope_start
    return f(x_mid, t + half_h)

class VectorField(nn.Module):
    """Time-conditioned MLP mapping a state ``x`` and time ``t`` to a vector.

    The time value is appended as an extra column to the input of every
    hidden layer. The output layer has zero-initialised kernel and bias, so
    a freshly initialised module returns zeros.

    Attributes:
        x_dim: Number of output features.
        hidden_nf: Width of each hidden layer.
        n_layers: Number of hidden layers.
    """
    x_dim: int
    hidden_nf: int
    n_layers: int

    def setup(self):
        zeros = nn.initializers.zeros_init()
        self.layers = [nn.Dense(features=self.hidden_nf) for _ in range(self.n_layers)]
        self.out_layer = nn.Dense(features=self.x_dim, kernel_init=zeros, bias_init=zeros)

    def __call__(self, x, t):
        """Evaluate the field at state ``x`` (2-D, batch first) and time ``t``."""
        batch_size = x.shape[0]
        # Broadcast t into a (batch, 1) column so it can be concatenated.
        t_column = jnp.ones((batch_size, 1))*t
        hidden = x
        for dense in self.layers:
            conditioned = jnp.concatenate([hidden, t_column], axis=1)
            hidden = jax.nn.swish(dense(conditioned))
        return self.out_layer(hidden)

class CNF(nn.Module):
    """Continuous normalizing flow integrated with fixed-step RK4.

    The flow is driven by an internal ``VectorField`` and integrated over
    ``n_steps`` equal steps on the unit time interval. The PRNG key used for
    the trace-estimator probes is derived from ``seed`` in ``setup``, so the
    same probe vectors are drawn on every call.

    Attributes:
        x_dim: State dimensionality passed to the vector field.
        hidden_nf: Hidden width of the vector field.
        n_layers: Number of hidden layers of the vector field.
        n_steps: Number of integration steps.
        seed: Seed for the probe-vector PRNG key.
    """
    x_dim: int
    hidden_nf: int
    n_layers: int
    n_steps: int
    seed: int

    def setup(self):
        self.vector_field = VectorField(
            x_dim=self.x_dim, hidden_nf=self.hidden_nf, n_layers=self.n_layers
        )
        self.key = jax.random.PRNGKey(self.seed)

    def __call__(self, x):
        """Integrate forward while collecting slopes and trace estimates.

        Args:
            x: Batch of states, 2-D with the batch on axis 0.

        Returns:
            A tuple ``(x_final, slopes, probe_vjps, trace_estimate)`` where
            ``slopes`` and ``probe_vjps`` are stacked over the steps and
            ``trace_estimate`` has one entry per batch row.
        """
        rng = self.key
        slopes, probe_vjps = [], []
        trace_acc = jnp.zeros(x.shape[0])
        for step in range(self.n_steps):
            t = step/self.n_steps

            def rk4_slope(state):
                return runge_kutta_step(self.vector_field, state, t, 1/self.n_steps)

            slope, pullback = jax.vjp(rk4_slope, x)
            slopes.append(slope)
            x = x + slope/self.n_steps
            rng = jax.random.split(rng)[0]
            probe = jax.random.normal(rng, x.shape)
            # Vector-Jacobian product of the slope with the random probe.
            (probe_vjp,) = pullback(probe)
            step_trace = jnp.einsum("Ni,Ni->N", probe_vjp, probe)
            trace_acc = trace_acc + step_trace/self.n_steps
            probe_vjps.append(probe_vjp)
        return x, jnp.stack(slopes), jnp.stack(probe_vjps), trace_acc

    def forward(self, x):
        """Integrate ``x`` forward over the unit interval and return the result."""
        state = x
        for step in range(self.n_steps):
            t = step/self.n_steps
            slope = runge_kutta_step(self.vector_field, state, t, 1/self.n_steps)
            state = state + slope/self.n_steps
        return state

    def inverse(self, x):
        """Integrate ``x`` backward from time 1 and return state and slopes.

        Returns:
            A tuple ``(x_final, slopes)`` with ``slopes`` stacked over steps.
        """
        slopes = []
        for step in range(self.n_steps):
            t = 1.0 - step/self.n_steps
            # Negative step size: RK4 stages are placed backwards in time.
            slope = runge_kutta_step(self.vector_field, x, t, -1/self.n_steps)
            slopes.append(slope)
            x = x - slope/self.n_steps
        return x, jnp.stack(slopes)
    
class FMCNF(nn.Module):
    """Flow driven by an externally supplied vector field.

    Integration uses ``n_steps`` equal steps on the unit time interval with
    the scheme selected by ``scheme``.

    Attributes:
        vector_field: Module called as ``vector_field(x, t)``; ``simulate``
            additionally reads its ``x_dim`` attribute.
        n_steps: Number of integration steps used by ``__call__`` and
            ``log_prob``.
        seed: Seed for the PRNG key (offset by 100 in ``setup``).
        scheme: One of ``"euler"``, ``"rk4"`` or ``"midpoint"``.
    """
    vector_field: nn.Module
    n_steps: int
    seed: int
    scheme: str

    def setup(self):
        self.key = random.PRNGKey(self.seed+100)

    def _slope(self, x, t, h):
        """Return the update slope at ``(x, t)`` for the configured scheme.

        Raises:
            ValueError: If ``scheme`` is not a supported name.
        """
        steppers = {
            "euler": euler_step,
            "rk4": runge_kutta_step,
            "midpoint": midpoint_step,
        }
        try:
            stepper = steppers[self.scheme]
        except KeyError:
            raise ValueError(
                f"Unknown integration scheme {self.scheme!r}; "
                f"expected one of {sorted(steppers)}"
            ) from None
        return stepper(self.vector_field, x, t, h)

    def __call__(self, x):
        """Integrate ``x`` forward from time 0 to 1 and return the result.

        Raises:
            ValueError: If ``scheme`` is not a supported name.
        """
        for i in range(self.n_steps):
            t = i/self.n_steps
            vf = self._slope(x, t, 1/self.n_steps)
            x = x + vf/self.n_steps
        return x

    def simulate(self, n_samples, n_steps):
        """Draw standard-normal samples and integrate them forward.

        Args:
            n_samples: Number of trajectories to simulate.
            n_steps: Number of integration steps for this simulation.

        Returns:
            The trajectory stacked along axis 0: the initial samples
            followed by the state after each step.

        Raises:
            ValueError: If ``scheme`` is not a supported name.
        """
        key = self.key
        x = random.normal(key, (n_samples, self.vector_field.x_dim))
        traj = [x]
        for i in range(n_steps):
            t = i/n_steps
            vf = self._slope(x, t, 1/n_steps)
            x = x + vf/n_steps
            traj.append(x)
        return jnp.stack(traj)

    def log_prob(self, x, jac_trace_init):
        """Estimate the log-density of ``x`` under the flow.

        ``x`` is integrated backwards from time 1 to time 0 with RK4
        (regardless of ``scheme``), accumulating a single-probe stochastic
        estimate of the Jacobian trace at every step. The result is the
        standard-normal log-density of the recovered base point minus the
        accumulated trace.

        Args:
            x: Batch of points, 2-D with the batch on axis 0.
            jac_trace_init: Initial value of the trace accumulator, one
                entry per batch row (or broadcastable to that).

        Returns:
            One log-density estimate per batch row.
        """
        key = self.key
        jac_trace_est = jac_trace_init
        for i in range(self.n_steps):
            t = 1.0 - i/self.n_steps
            key = random.split(key)[1]
            e = random.normal(key, shape=x.shape)
            # Advance the key once more without sampling: an unused second
            # probe used to be drawn here, and keeping the split leaves the
            # sequence of probe vectors unchanged.
            key = random.split(key)[1]

            def backward_slope(state):
                # Negative step size so the RK4 stages are placed backwards
                # in time, matching the backward update applied below.
                return runge_kutta_step(self.vector_field, state, t, -1/self.n_steps)

            vf, jac_e = jax.jvp(backward_slope, (x,), (e,))
            x = x - vf/self.n_steps
            jac_trace_est = jac_trace_est + jnp.einsum("Ni,Ni->N", jac_e, e)/self.n_steps
        # Log-density of a standard normal in x.shape[-1] dimensions.
        dim = x.shape[-1]
        log_p0 = -0.5*jnp.sum(x**2, axis=-1) - 0.5*dim*jnp.log(2*jnp.pi)
        return log_p0 - jac_trace_est
        
    
class VAE(nn.Module):
    """Variational autoencoder with dense ReLU encoder and decoder stacks.

    Both the encoder and the decoder end in two linear heads producing a
    mean and a log-variance.

    Attributes:
        x_dim: Number of features of the decoder output heads.
        n_encoder_layers: Number of hidden layers in the encoder.
        n_decoder_layers: Number of hidden layers in the decoder.
        hidden_nf: Width of every hidden layer.
        latent_nf: Number of latent features.
    """

    x_dim: int
    n_encoder_layers: int
    n_decoder_layers: int
    hidden_nf: int
    latent_nf: int

    def setup(self):
        self.encoder = [nn.Dense(features=self.hidden_nf) for _ in range(self.n_encoder_layers)]
        self.mu_z = nn.Dense(features=self.latent_nf)
        self.log_var_z = nn.Dense(features=self.latent_nf)
        self.decoder = [nn.Dense(features=self.hidden_nf) for _ in range(self.n_decoder_layers)]
        self.mu_x = nn.Dense(features=self.x_dim)
        self.log_var_x = nn.Dense(features=self.x_dim)

    def encode(self, x):
        """Map inputs to the latent mean and log-variance.

        Returns:
            A tuple ``(mu_z, log_var_z)``.
        """
        for layer in self.encoder:
            x = jax.nn.relu(layer(x))
        return self.mu_z(x), self.log_var_z(x)

    def decode(self, z):
        """Map latents to the output mean and log-variance.

        Returns:
            A tuple ``(mu_x, log_var_x)``.
        """
        for layer in self.decoder:
            z = jax.nn.relu(layer(z))
        return self.mu_x(z), self.log_var_x(z)

    def sample(self, mu, log_var, key):
        """Draw a reparameterised Gaussian sample.

        Args:
            mu: Mean of the Gaussian.
            log_var: Log-variance of the Gaussian, same shape as ``mu``.
            key: JAX PRNG key used to draw the noise.

        Returns:
            ``mu + eps * std`` with ``eps`` standard normal noise.
        """
        # log_var is the log of the variance, so the standard deviation is
        # exp(0.5 * log_var); exp(log_var) would be the variance itself.
        std = jnp.exp(0.5 * log_var)
        eps = random.normal(key, mu.shape)
        return mu + eps * std

    def __call__(self, x, key):
        """Encode ``x``, sample a latent, and decode it.

        Args:
            x: Input batch.
            key: JAX PRNG key used for the latent sample.

        Returns:
            A tuple ``(mu_x, log_var_x, mu_z, log_var_z)``.
        """
        mu_z, log_var_z = self.encode(x)
        z = self.sample(mu_z, log_var_z, key)
        mu_x, log_var_x = self.decode(z)
        return mu_x, log_var_x, mu_z, log_var_z